{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在上一节中，我们简单介绍了nanoGPT的基本方式，但是我们也能看出这个GPT过于简陋，其生成效果急需进一步的提高\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 首先是编码，嵌入方式的改进\n",
    "- 之前我们采用的是简单的一一对应的方式\n",
    "    - 对于只有26个大写和26个小写的英文字符来说，这样似乎还算合理，因为只是把50多个或者60多个字符按照顺序去编码为对应的数字而已"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 例如，OpenAI 在之前的GPT-2，GPT-3，GPT-4系列中使用的是其\n",
    "发布的 tiktoken 库，而 Google 也有其自己的分词/编码工具 SentencePiece，他们只是\n",
    "不同的方式，但做的都是将“完整的句子转化成整数向量编码”这样的一件事情。例\n",
    "如，我们现在可以利用 tiktoken 库来十分方便地调用 GPT-2 中所训练的 tokenizer，从\n",
    "而实现编解码过程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[37046, 21043, 10890, 247, 162, 224, 253, 35894]\n",
      "我是孙悟空\n"
     ]
    }
   ],
   "source": [
    "# Way2\n",
    "import tiktoken\n",
    "enc = tiktoken.get_encoding(\"cl100k_base\")\n",
    "# enc = tiktoken.get_encoding(\"gpt2\")\n",
    "print(enc.encode(\"我是孙悟空\"))\n",
    "print(enc.decode(enc.encode(\"我是孙悟空\")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "之前经过encoder之后的数量不会改变"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "489540"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 还是跟之前一样，我们先进行文本的读取\n",
    "with open('../data/Xiyou.txt', 'r', encoding='utf-8') as f:\n",
    "    text = f.read()\n",
    "\n",
    "n = int(0.5*len(text)) # 前90%都是训练集，后10%都是测试集\n",
    "text = text[:n]\n",
    "\n",
    "  # 对文本进行编码\n",
    "len(enc.encode(text))   # 获取编码之后的长度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来我们检查一下这样是否work、以及这样是否可以提升性能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "# 开始划分训练集和测试集\n",
    "code = enc.encode(text)\n",
    "data = torch.tensor(code, dtype=torch.long)   # Way 1 \n",
    "\n",
    "vocab_size = len(code)\n",
    "n = int(0.9*len(data)) # 前90%都是训练集，后10%都是测试集\n",
    "train_data = data[:n]\n",
    "val_data = data[n:]\n",
    "\n",
    "# 进行参数的设置\n",
    "device = 'cpu'  # 模型运行的设备\n",
    "block_size = 16  # 每个单元的最大长度\n",
    "batch_size = 32 # 同时运行的批次大小\n",
    "learning_rate = 0.3\n",
    "max_iters = 1000\n",
    "eval_interval = 300  # 对当前模型的运行结果进行评估的epoch数量\n",
    "eval_iters = 200\n",
    "\n",
    "# 每次从数据集中获取x和y\n",
    "def get_batch(split):\n",
    "    # generate a small batch of data of inputs x and targets y\n",
    "    data = train_data if split == 'train' else val_data\n",
    "    ix = torch.randint(len(data) - block_size, (batch_size,))\n",
    "    x = torch.stack([data[i:i+block_size] for i in ix])\n",
    "    y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
    "    return x, y\n",
    "\n",
    "class BLM(nn.Module):\n",
    "    def __init__(self,vocab_size):\n",
    "        super().__init__()\n",
    "        self.token_embedding_table = nn.Embedding(vocab_size,vocab_size)\n",
    "    \n",
    "    def forward(self,idx,targets = None):\n",
    "        # 这里的self，就是我们之前的x，target就是之前的y\n",
    "        logits = self.token_embedding_table(idx) # (B,T)  -> (B,T,C)  # 这里我们通过Embedding操作直接得到预测分数\n",
    "        # 这里的预测分数过程与二分类或者多分类的分数是大致相同的\n",
    "\n",
    "        \n",
    "        if targets is None:\n",
    "            loss = None\n",
    "        else:   \n",
    "            B, T, C = logits.shape\n",
    "            logits = logits.view(B*T, C)\n",
    "            targets = targets.view(B*T)    # 这里我们调整一下形状，以符合torch的交叉熵损失函数对于输入的变量的要求\n",
    "            loss = F.cross_entropy(logits, targets)\n",
    "\n",
    "        return logits, loss\n",
    "\n",
    "    def generate(self, idx, max_new_tokens):\n",
    "        '''\n",
    "        idx 是现在的输入的(B, T) array of indices in the current context\n",
    "        max_new_tokens 是产生的最大的tokens数量\n",
    "        '''\n",
    "\n",
    "        for _ in range(max_new_tokens):\n",
    "            # 得到预测的结果\n",
    "            logits, loss = self(idx)\n",
    "            \n",
    "            # 只关注最后一个的预测\n",
    "            logits = logits[:, -1, :] # becomes (B, C)\n",
    "            # 对概率值应用softmax\n",
    "            probs = F.softmax(logits, dim=-1) # (B, C)\n",
    "            # 对input的每一行做n_samples次取值，输出的张量是每一次取值时input张量对应行的下标，也即找到概率值输出最大的下标，也对应着最大的编码\n",
    "            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
    "            # 将新产生的编码加入到之前的编码中，形成新的编码\n",
    "            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
    "        return idx "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m在当前单元格或上一个单元格中执行代码时 Kernel 崩溃。\n",
      "\u001b[1;31m请查看单元格中的代码，以确定故障的可能原因。\n",
      "\u001b[1;31m单击<a href='https://aka.ms/vscodeJupyterKernelCrash'>此处</a>了解详细信息。\n",
      "\u001b[1;31m有关更多详细信息，请查看 Jupyter <a href='command:jupyter.viewOutput'>log</a>。"
     ]
    }
   ],
   "source": [
    "# 创建模型\n",
    "model = BLM(vocab_size)\n",
    "m = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 简单的写下\n",
    "optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)\n",
    "for steps in range(1000): # 随着迭代次数的增长，获得的效果会不断变好\n",
    "\n",
    "    xb, yb = get_batch('train')\n",
    "\n",
    "    logits, loss = m(xb, yb)  # 用于评估损失\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    print(epoch)\n",
    "    print(loss.item())\n",
    "\n",
    "print(loss.item())\n",
    "\n",
    "optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(enc.decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
